from CML_MMD2.core.kernel import get_kernel, optimize_kernel_binsearch_only
from CML_MMD2.core.mmd import mmd_neg_biased_batched


def get_MMD2_values(D_Xs, D_Ys, V_X, V_Y, kernel, device, batch_size=256):
    """Score every dataset in ``D_Xs`` against the reference ``V_X``.

    Each dataset is passed to ``mmd_neg_biased_batched`` together with ``V_X``,
    ``kernel``, ``device`` and ``batch_size``. Only the first element of the
    3-tuple it returns is kept, converted to a Python scalar with ``.item()``.
    Judging by the helper's name this is a negated biased MMD estimate; confirm
    against ``CML_MMD2.core.mmd``.

    ``D_Ys`` and ``V_Y`` are accepted for call-site compatibility but are not used.

    Returns:
        list: one scalar per entry of ``D_Xs``, in the same order.
    """
    # Phase 1: run the (potentially expensive) batched computation for every dataset.
    raw_outputs = []
    for candidate in D_Xs:
        raw_outputs.append(mmd_neg_biased_batched(candidate, V_X, kernel, device, batch_size))

    # Phase 2: keep only the score; the two auxiliary outputs are discarded.
    scores = []
    for neg_mmd, _aux_x, _aux_y in raw_outputs:
        scores.append(neg_mmd.item())
    return scores


def get_extracted(model, loader, device):
    """Run ``model.extract_fc2`` over every batch of ``loader`` and stack the results.

    Args:
        model: a ``torch.nn.Module`` exposing ``extract_fc2(batch)``. It is moved
            to ``device`` and put in eval mode. Both changes persist after this
            call because ``Module.to`` / ``Module.eval`` act in place.
        loader: iterable yielding ``(batch_data, batch_target)`` pairs. Only
            ``batch_data`` is used; targets are ignored.
        device: ``torch.device`` that the model and each input batch are moved to.

    Returns:
        numpy.ndarray: the per-batch outputs concatenated along dim 0, on the CPU.
    """
    model = model.to(device)
    model.eval()
    features = []
    with torch.no_grad():
        for batch_data, _batch_target in loader:
            # Targets play no part in feature extraction, so they are deliberately
            # not copied to the device; that copy was wasted transfer work.
            features.append(model.extract_fc2(batch_data.to(device)))

    # detach() is kept defensively in case extract_fc2 returns a tensor that
    # still requires grad (e.g. a parameter), which .numpy() would reject.
    return torch.cat(features).detach().cpu().numpy()

from copy import deepcopy

from utils import cwd, set_deterministic, save_results

from data_utils import _get_loader
from utils import get_trained_feature_extractor, get_accuracy
import numpy as np
import torch

from reg_data_utils import  _get_CaliH, _get_KingH, _get_FaceA, _get_census, huber_regression, assign_data
from os.path import join as oj

from tqdm import tqdm
import argparse

baseline = 'MMD_sq'

if __name__ == '__main__':
    

    print(f"----- Running experiment for {baseline} -----")

    parser = argparse.ArgumentParser(description='Process which dataset to run for regression.')
    parser.add_argument('-N', '--N', help='Number if data venrods.', type=int, required=True, default=5)
    parser.add_argument('-m', '--size', help='Size of sample datasets.', type=int, required=True, default=1500)
    parser.add_argument('-P', '--dataset', help='Pick the dataset to run.', type=str, required=True)
    parser.add_argument('-Q', '--Q_dataset', help='Pick the Q dataset.', type=str, required=True, choices=['KingH', 'Census17'])
    parser.add_argument('-n_t', '--n_trials', help='Number of trials.', type=int, default=5)

    # parser.add_argument('-nocuda', dest='cuda', help='Not to use cuda even if available.', action='store_false')
    # parser.add_argument('-cuda', dest='cuda', help='Use cuda if available.', action='store_true')

    cmd_args = parser.parse_args()
    print(cmd_args)

    dataset = cmd_args.dataset
    Q_dataset = cmd_args.Q_dataset
    N = cmd_args.N
    size = cmd_args.size
    n_trials = cmd_args.n_trials

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    set_deterministic()

    optimize_kernel_params = False
    
    kernel_name = 'se'
    values_over_trials, values_hat_over_trials =[], []
    for _ in tqdm(range(n_trials), desc =f'A total of {n_trials} trials.'):
        # raw data
        D_Xs, D_Ys, V_X, V_Y = assign_data(N, size, dataset, Q_dataset)

        d = D_Xs[0].shape[1]

        kernel = get_kernel(kernel_name, d, 1., device)
        if optimize_kernel_params:
            print("Optimizing kernel parameters")
            kernel, lengthscale = optimize_kernel_binsearch_only(kernel, device, torch.stack(D_Xs), V_X)

        MMD2_values = get_MMD2_values(D_Xs, None, V_X, None, kernel, device)
        values_over_trials.append(MMD2_values)

        MMD2_values_hat = get_MMD2_values(D_Xs, None, torch.cat(D_Xs), None, kernel, device)
        values_hat_over_trials.append(MMD2_values_hat)

    results = {'values_over_trials':values_over_trials, 'values_hat_over_trials': values_hat_over_trials, 'N':N, 'size':size, 'n_trials': n_trials,
    'd':d, 'kernel':kernel}
    save_results(baseline=baseline, exp_name=oj('regression', f'{dataset}_vs_{Q_dataset}-N{N} m{size} n_trials{n_trials}'), **results)
